<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  
  <link rel="shortcut icon" href="../../img/favicon.ico">
  <title>CIFAR-10 CNN - Keras 中文文档</title>
  <link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>

  <link rel="stylesheet" href="../../css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
  <link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
  
  <script>
    // Current page data
    var mkdocs_page_name = "CIFAR-10 CNN";
    var mkdocs_page_input_path = "examples/cifar10_cnn.md";
    var mkdocs_page_url = "/zh/examples/cifar10_cnn/";
  </script>
  
  <script src="../../js/jquery-2.1.1.min.js" defer></script>
  <script src="../../js/modernizr-2.8.3.min.js" defer></script>
  <script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
  <script>hljs.initHighlightingOnLoad();</script> 
  
  <script>
      (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
      (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
      m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
      })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

      ga('create', 'UA-61785484-1', 'keras.io');
      ga('send', 'pageview');
  </script>
  
</head>

<body class="wy-body-for-nav" role="document">

  <div class="wy-grid-for-nav">

    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
      <div class="wy-side-nav-search">
        <a href="../.." class="icon icon-home"> Keras 中文文档</a>
        <div role="search">
  <form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" title="Type search term here" />
  </form>
</div>
      </div>

      <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
	<ul class="current">
	  
          
            <li class="toctree-l1">
		
    <a class="" href="../..">主页</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../why-use-keras/">为什么选择 Keras?</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">快速开始</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../getting-started/sequential-model-guide/">Sequential 顺序模型指引</a>
                </li>
                <li class="">
                    
    <a class="" href="../../getting-started/functional-api-guide/">函数式 API 指引</a>
                </li>
                <li class="">
                    
    <a class="" href="../../getting-started/faq/">FAQ 常见问题解答</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">模型</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../models/about-keras-models/">关于 Keras 模型</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/sequential/">Sequential 顺序模型 API</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/model/">函数式 API</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">Layers</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../layers/about-keras-layers/">关于 Keras 网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/core/">核心网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/convolutional/">卷积层 Convolutional</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/pooling/">池化层 Pooling</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/local/">局部连接层 Locally-connected</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/recurrent/">循环层 Recurrent</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/embeddings/">嵌入层 Embedding</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/merge/">融合层 Merge</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/advanced-activations/">高级激活层 Advanced Activations</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/normalization/">标准化层 Normalization</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/noise/">噪声层 Noise</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/wrappers/">层封装器 wrappers</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/writing-your-own-keras-layers/">编写你自己的层</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">数据预处理</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../preprocessing/sequence/">序列预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/text/">文本预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/image/">图像预处理</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../losses/">损失函数 Losses</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../metrics/">评估标准 Metrics</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../optimizers/">优化器 Optimizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../activations/">激活函数 Activations</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../callbacks/">回调函数 Callbacks</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../datasets/">常用数据集 Datasets</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../applications/">应用 Applications</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../backend/">后端 Backend</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../initializers/">初始化 Initializers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../regularizers/">正则化 Regularizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../constraints/">约束 Constraints</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../visualization/">可视化 Visualization</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../scikit-learn-api/">Scikit-learn API</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../utils/">工具</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../contributing/">贡献</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">经典样例</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../addition_rnn/">Addition RNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../babi_rnn/">Baby RNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../babi_memnn/">Baby MemNN</a>
                </li>
                <li class=" current">
                    
    <a class="current" href="./">CIFAR-10 CNN</a>
    <ul class="subnav">
            
    <li class="toctree-l3"><a href="#cifar10">在 CIFAR10 小型图像数据集上训练一个深度卷积神经网络。</a></li>
    

    </ul>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn_capsule/">CIFAR-10 CNN-Capsule</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn_tfaugment2d/">CIFAR-10 CNN with augmentation (TF)</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_resnet/">CIFAR-10 ResNet</a>
                </li>
                <li class="">
                    
    <a class="" href="../conv_filter_visualization/">Convolution filter visualization</a>
                </li>
                <li class="">
                    
    <a class="" href="../image_ocr/">Image OCR</a>
                </li>
                <li class="">
                    
    <a class="" href="../imdb_bidirectional_lstm/">Bidirectional LSTM</a>
                </li>
    </ul>
	    </li>
          
        </ul>
      </div>
      &nbsp;
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" role="navigation" aria-label="top navigation">
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../..">Keras 中文文档</a>
      </nav>

      
      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <li><a href="../..">Docs</a> &raquo;</li>
    
      
        
          <li>经典样例 &raquo;</li>
        
      
    
    <li>CIFAR-10 CNN</li>
    <li class="wy-breadcrumbs-aside">
      
        <a href="https://github.com/keras-team/keras-docs-zh/edit/master/docs/examples/cifar10_cnn.md"
          class="icon icon-github"> Edit on GitHub</a>
      
    </li>
  </ul>
  <hr/>
</div>
          <div role="main">
            <div class="section">
              
                <h1 id="cifar10">在 CIFAR10 小型图像数据集上训练一个深度卷积神经网络。</h1>
<p>在 25 轮迭代后 验证集准确率达到 75%，在 50 轮后达到 79%。
(尽管目前仍然欠拟合)。</p>
<pre><code class="python">from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import os

batch_size = 32
num_classes = 10
epochs = 100
data_augmentation = True
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# 数据，切分为训练和测试集。
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# 将类向量转换为二进制类矩阵。
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# 初始化 RMSprop 优化器。
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)

# 利用 RMSprop 来训练模型。
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
else:
    print('Using real-time data augmentation.')
    # 这一步将进行数据处理和实时数据增益。data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # 将整个数据集的均值设为0
        samplewise_center=False,  # 将每个样本的均值设为0
        featurewise_std_normalization=False,  # 将输入除以整个数据集的标准差
        samplewise_std_normalization=False,  # 将输入除以其标准差
        zca_whitening=False,  # 运用 ZCA 白化
        zca_epsilon=1e-06,  # ZCA 白化的 epsilon值
        rotation_range=0,  # 随机旋转图像范围 (角度, 0 to 180)
        # 随机水平移动图像 (总宽度的百分比)
        width_shift_range=0.1,
        # 随机垂直移动图像 (总高度的百分比)
        height_shift_range=0.1,
        shear_range=0.,  # 设置随机裁剪范围
        zoom_range=0.,  # 设置随机放大范围
        channel_shift_range=0.,  # 设置随机通道切换的范围
        # 设置填充输入边界之外的点的模式
        fill_mode='nearest',
        cval=0.,  # 在 fill_mode = &quot;constant&quot; 时使用的值
        horizontal_flip=True,  # 随机水平翻转图像
        vertical_flip=False,  # 随机垂直翻转图像
        # 设置缩放因子 (在其他转换之前使用)
        rescale=None,
        # 设置将应用于每一个输入的函数
        preprocessing_function=None,
        # 图像数据格式，&quot;channels_first&quot; 或 &quot;channels_last&quot; 之一
        data_format=None,
        # 保留用于验证的图像比例（严格在0和1之间）
        validation_split=0.0)

    # 计算特征标准化所需的计算量
    # (如果应用 ZCA 白化，则为 std，mean和主成分).
    datagen.fit(x_train)

    # 利用由 datagen.flow() 生成的批来训练模型
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=4)

# 保存模型和权重
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# 评估训练模型
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
</code></pre>
              
            </div>
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="../cifar10_cnn_capsule/" class="btn btn-neutral float-right" title="CIFAR-10 CNN-Capsule">Next <span class="icon icon-circle-arrow-right"></span></a>
      
      
        <a href="../babi_memnn/" class="btn btn-neutral" title="Baby MemNN"><span class="icon icon-circle-arrow-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <!-- Copyright etc -->
    
  </div>

  Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
      
        </div>
      </div>

    </section>

  </div>

  <div class="rst-versions" role="note" style="cursor: pointer">
    <span class="rst-current-version" data-toggle="rst-current-version">
      
          <a href="https://github.com/keras-team/keras-docs-zh/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
      
      
        <span><a href="../babi_memnn/" style="color: #fcfcfc;">&laquo; Previous</a></span>
      
      
        <span style="margin-left: 15px"><a href="../cifar10_cnn_capsule/" style="color: #fcfcfc">Next &raquo;</a></span>
      
    </span>
</div>
    <script>var base_url = '../..';</script>
    <script src="../../js/theme.js" defer></script>
      <script src="../../search/main.js" defer></script>

</body>
</html>
